import argparse
import random
import numpy as np
import os
import torch
import torch.nn.functional as F
from scipy.signal import find_peaks
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from decord import VideoReader, cpu
import pickle
import json
import warnings

# NOTE(review): this silences *every* warning for the whole process (no category
# or module filter), including numerical RuntimeWarnings from numpy; consider
# narrowing it so real problems are not hidden.
warnings.filterwarnings("ignore")


class VideoSceneAnalyzer:
    """Video scene analyzer based on inter-frame feature similarity.

    Each frame is scored by a triangle-weighted cosine similarity to its
    neighbours. Local minima of that score become scene boundaries. Adjacent
    scenes are then merged, most similar pair first, until at most
    ``target_scene_num`` scenes remain.
    """

    def __init__(self, window_size=3, min_scene_length=5):
        # Full width of the triangular neighbourhood; offsets up to
        # window_size // 2 on each side are used.
        self.window_size = window_size
        # Minimum spacing, in frames, between detected boundaries
        # (the ``distance`` argument of scipy's find_peaks).
        self.min_scene_length = min_scene_length

    def compute_weighted_similarity(self, features):
        """Return one smoothed similarity score per frame.

        For frame ``i`` the score is the weighted mean cosine similarity to the
        neighbours ``i - d`` and ``i + d``, for ``1 <= d <= window_size // 2``,
        that lie inside the video. Each neighbour has triangular weight
        ``window_size // 2 + 1 - d``. Frames without any in-range neighbour
        score 1.0.

        Args:
            features: 2-D numpy array or torch tensor, one feature row per frame.

        Returns:
            1-D numpy array of length ``features.shape[0]``.
        """
        num_frames = features.shape[0]
        tensor_feats = features if isinstance(features, torch.Tensor) else torch.from_numpy(features).float()
        norm_feats = F.normalize(tensor_feats, p=2, dim=1)
        # Symmetric window. Note that ``-w // 2`` would parse as ``(-w) // 2``
        # and reach one step too far to the left for odd ``w``.
        half = self.window_size // 2
        scores = []
        for i in range(num_frames):
            sim, total_w = 0., 0.
            for offset in range(-half, half + 1):
                j = i + offset
                if offset == 0 or j < 0 or j >= num_frames:
                    continue
                weight = half + 1 - abs(offset)  # always >= 1 inside the window
                s = torch.cosine_similarity(norm_feats[i].unsqueeze(0), norm_feats[j].unsqueeze(0)).item()
                sim += weight * s
                total_w += weight
            scores.append(sim / total_w if total_w > 0 else 1.)
        return np.array(scores)

    def find_boundaries(self, similarity_scores):
        """Return sorted indices of local minima of ``similarity_scores``.

        Minima closer than ``min_scene_length`` frames are thinned out by
        ``find_peaks``. The spacing is clamped to at least 1 because
        ``find_peaks`` rejects a smaller ``distance``. The first and last
        samples are never reported.
        """
        scores = np.asarray(similarity_scores)
        minima, _ = find_peaks(-scores, distance=max(1, self.min_scene_length))
        return sorted(minima.tolist())

    def scene_similarity(self, features, scenes):
        """Return the cosine-similarity matrix of mean-pooled scene features.

        Each ``(start, end)`` scene is represented by the mean of
        ``features[start:end]``.
        """
        scene_feats = np.array([np.mean(features[start:end], axis=0) for start, end in scenes])
        return cosine_similarity(scene_feats)

    def merge_scenes(self, scenes, dinov2_features, target_scene_num):
        """Repeatedly merge the most similar adjacent pair until ``target_scene_num`` remain.

        Similarities are recomputed after every merge. Merging stops early when
        only one scene is left.
        """
        curr = list(scenes)
        while len(curr) > target_scene_num:
            sim_matrix = self.scene_similarity(dinov2_features, curr)
            idx = max(
                ((i, sim_matrix[i, i + 1]) for i in range(len(curr) - 1)),
                key=lambda x: x[1], default=(None, -1)
            )[0]
            if idx is None:
                break
            merged = (curr[idx][0], curr[idx + 1][1])
            curr = curr[:idx] + [merged] + curr[idx + 2:]
        return curr

    def analyze(self, dinov2_features, target_scene_num):
        """Segment a video into at most ``target_scene_num`` scenes.

        Returns:
            List of half-open ``(start, end)`` frame ranges that tile
            ``[0, dinov2_features.shape[0])``.
        """
        sim_scores = self.compute_weighted_similarity(dinov2_features)
        boundaries = self.find_boundaries(sim_scores)
        cuts = [0] + boundaries + [dinov2_features.shape[0]]
        scenes = list(zip(cuts[:-1], cuts[1:]))
        if len(scenes) > target_scene_num:
            scenes = self.merge_scenes(scenes, dinov2_features, target_scene_num)
        return scenes


class QueryKeyFrameSelector:
    """Query-specific keyframe selector.

    Combines the best-matching frame of every scene with extra frames chosen
    for query relevance (descending ITM score) and visual diversity (DINOv2
    cosine similarity below an adaptive threshold).
    """

    def __init__(self, alpha):
        # Half-width of the threshold band, in standard deviations
        # (see _adaptive_thresholds).
        self.alpha = alpha

    def select_scene_keyframes(self, scenes, itm_scores):
        """Return the highest-ITM frame index of every non-empty ``(start, end)`` scene."""
        return [
            start + int(np.argmax(itm_scores[start:end]))
            for start, end in scenes if end > start
        ]

    def mmr_select(self, itm_scores, dinov2_feats, selected, k):
        """Greedily add diverse, query-relevant frames until ``k`` frames are chosen.

        Remaining frames are visited in descending ITM order. A frame is
        accepted when its maximum cosine similarity to every already-chosen
        frame is below the current threshold. The threshold starts at the
        strict bound from :meth:`_adaptive_thresholds` and is relaxed by 0.05
        per pass up to the loose bound. This is a threshold-based diversity
        filter, not the classic MMR relevance/diversity trade-off.

        Args:
            itm_scores: per-frame query relevance scores.
            dinov2_feats: 2-D array of per-frame features.
            selected: frame indices already chosen (not mutated).
            k: total frame budget, including ``selected``.

        Returns:
            ``(additional, mu, sigma)``. ``additional`` holds the new indices in
            acceptance order; it may be shorter than ``k - len(selected)`` if the
            loosest threshold is exhausted. ``mu`` and ``sigma`` are the
            threshold statistics.
        """
        needed = k - len(selected)
        taken = set(selected)
        # Candidates in descending ITM order; the sort is stable, so ties keep frame order.
        remain = sorted(
            (i for i in range(len(itm_scores)) if i not in taken),
            key=lambda i: itm_scores[i], reverse=True,
        )
        strict, loose, mu, sigma = self._adaptive_thresholds(dinov2_feats, remain, selected)
        delta = 0.05  # threshold relaxation step per pass
        threshold = strict
        additional, all_selected = [], list(selected)
        while len(additional) < needed and threshold <= loose:
            for idx in remain:
                if len(additional) >= needed:
                    break
                if idx in taken:
                    continue
                f = dinov2_feats[idx:idx + 1]
                max_sim = np.max(cosine_similarity(f, dinov2_feats[all_selected])) if all_selected else 0.
                if max_sim < threshold:
                    additional.append(idx)
                    all_selected.append(idx)
                    taken.add(idx)
            if len(additional) < needed:
                if threshold >= loose:
                    break
                threshold = min(threshold + delta, loose)
        return additional, mu, sigma

    def _adaptive_thresholds(self, dinov2_feats, remain, selected):
        """Derive the ``(strict, loose, mu, sigma)`` thresholds for mmr_select.

        ``mu`` and ``sigma`` are the mean and std of each remaining frame's
        maximum cosine similarity to ``selected``. The bounds are
        ``mu - alpha * sigma`` and ``mu + alpha * sigma``, clipped to [0, 1].
        When ``selected`` is empty, every per-frame value is taken as 0.
        When ``remain`` is empty, the statistics are NaN (mean of an empty list).
        """
        selected_feats = dinov2_feats[selected]
        max_sims = [
            np.max(cosine_similarity(dinov2_feats[i:i + 1], selected_feats)[0])
            for i in remain
        ] if len(selected) > 0 else np.zeros(len(remain))
        mu, sigma = np.mean(max_sims), np.std(max_sims)
        strict = np.clip(mu - self.alpha * sigma, 0., 1.)
        loose = np.clip(mu + self.alpha * sigma, 0., 1.)
        return strict, loose, mu, sigma

    def select(self, scenes, dinov2_feats, itm_scores, k):
        """Select up to ``k`` keyframes for one query.

        The best frame of each scene is taken first. If there are more scenes
        than ``k``, only the ``k`` scene keyframes with the highest ITM scores
        are kept, so that later scenes are not dropped merely for their
        position. Any remaining budget is filled by :meth:`mmr_select`.

        Returns:
            ``(keyframes, mu, sigma)``: the sorted frame indices (at most ``k``)
            and the mmr_select statistics, which are ``(0, 0)`` when it was not run.
        """
        # dict.fromkeys drops duplicate keys (e.g. from overlapping scenes) while keeping order.
        scene_keys = list(dict.fromkeys(self.select_scene_keyframes(scenes, itm_scores)))
        if len(scene_keys) > k:
            scene_keys = sorted(scene_keys, key=lambda i: itm_scores[i], reverse=True)[:max(k, 0)]
        add_keys, mu, sigma = [], 0, 0
        if k - len(scene_keys) > 0:
            add_keys, mu, sigma = self.mmr_select(itm_scores, dinov2_feats, scene_keys, k)
        all_keyframes = sorted(set(scene_keys + add_keys))
        return all_keyframes[:k], mu, sigma


def _pad_with_uniform_frames(frames, target_k, video_path):
    """Return a copy of ``frames`` topped up with evenly spaced video frames (best effort)."""
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frames = len(vr)
    del vr  # only the frame count is needed
    padded = set(frames)
    if total_frames <= 0:
        return padded
    # Prefer well-spread frames; fall back to every frame if those collide.
    spread = np.linspace(0, total_frames - 1, target_k, dtype=int).tolist()
    for candidates in (spread, range(total_frames)):
        for frame in candidates:
            if len(padded) >= target_k:
                return padded
            padded.add(frame)
    return padded


def map_and_ensure_keyframes(sampled_indices, frame_map, target_k, video_path, dinov2_feats):
    """Map selected feature indices to original video frames and pad to ``target_k``.

    ``frame_map[i]`` is the original-video frame index of feature row ``i``.
    Several rows may map to the same video frame, so padding is driven by the
    number of *distinct* mapped frames. Farthest-point sampling in DINOv2
    feature space adds rows until ``target_k`` distinct frames are reached. If
    the feature rows run out first, evenly spaced frames of the video at
    ``video_path`` are used as a last resort; the video is opened only then.

    Returns:
        Sorted list of original frame indices. It has exactly ``target_k``
        entries unless the video itself has fewer distinct frames.
    """
    mapped = sorted(set(frame_map[i] for i in sampled_indices))
    if len(mapped) >= target_k:
        # Only more than target_k if a caller passes extra indices; a random
        # (unseeded) subset is then kept.
        return sorted(random.sample(mapped, target_k))

    frames = set(mapped)
    seeds = list(sampled_indices)
    while len(frames) < target_k:
        extra = farthest_point_sampling_with_seeds(dinov2_feats, seeds, target_k - len(frames))
        if not extra:
            break  # every feature row is already selected
        seeds.extend(extra)
        frames.update(frame_map[i] for i in extra)

    if len(frames) < target_k:
        frames = _pad_with_uniform_frames(frames, target_k, video_path)
    return sorted(frames)

def farthest_point_sampling_with_seeds(features, seed_indices, num_new):
    """Greedy farthest-point sampling in cosine-distance space, seeded by existing picks.

    Each step adds the row whose cosine distance (``1 - cos``) to its nearest
    already-selected row is largest; ties go to the lowest index. A running
    nearest-distance per row is kept, so each step costs O(N * D) instead of
    recomputing distances to every selected row for every candidate.

    Args:
        features: 2-D array, one feature vector per row.
        seed_indices: rows treated as already selected. They are neither
            returned nor mutated. When empty, sampling starts from row 0,
            which is then returned as the first new point.
        num_new: maximum number of rows to add.

    Returns:
        The newly selected row indices, in selection order. The list is
        shorter than ``num_new`` if every row ends up selected.
    """
    num_rows = features.shape[0]
    if num_new <= 0 or num_rows == 0:
        return []
    norm_feats = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-8)
    picked = []
    seeds = list(seed_indices)
    if not seeds:
        # No seeds: start FPS from an arbitrary, deterministic point.
        seeds = [0]
        picked.append(0)
    # Cosine distance from every row to its nearest selected row.
    nearest = np.min(1 - norm_feats @ norm_feats[seeds].T, axis=1)
    nearest[seeds] = -np.inf  # never re-pick a selected row
    while len(picked) < num_new:
        nxt = int(np.argmax(nearest))
        if nearest[nxt] == -np.inf:
            break  # all rows selected
        picked.append(nxt)
        nearest = np.minimum(nearest, 1 - norm_feats @ norm_feats[nxt])
        nearest[nxt] = -np.inf
    return picked

def load_features_and_scores(data_dir, video_id):
    """Read the cached DINOv2 features and ITM score file of one video.

    Both files live in ``<data_dir>/<video_id>/``: ``dinov2_features.pkl``
    (unpickled as-is) and ``similarity_scores.json``.

    Returns:
        ``(features, scores)``, or ``(None, None)`` if either file is absent.
    """
    video_dir = os.path.join(data_dir, video_id)
    dino_path, scores_path = (
        os.path.join(video_dir, name)
        for name in ('dinov2_features.pkl', 'similarity_scores.json')
    )
    if not (os.path.exists(dino_path) and os.path.exists(scores_path)):
        return None, None

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated feature caches.
    with open(dino_path, 'rb') as handle:
        dino_feats = pickle.load(handle)
    with open(scores_path, 'r', encoding='utf-8') as handle:
        scores = json.load(handle)
    return dino_feats, scores

def _uniform_frame_indices(video_path, num_frames):
    """Return ``num_frames`` evenly spaced frame indices spanning the whole video."""
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total = len(vr)
    del vr  # only the frame count is needed
    return np.linspace(0, total - 1, num_frames, dtype=int).tolist()


def main():
    """Select query-specific keyframes for every question and write them to JSON.

    For each video:

    1. Load the cached DINOv2 features and per-question ``blip2_similarities``
       scores.
    2. Segment the video into scenes.
    3. Select ``--max_frames`` keyframes per question and map them to
       original frame indices.
    4. Store them as ``questions[i]['keyframe_indices']``.

    Videos with fewer feature rows than ``--max_frames``, or with missing cache
    files, fall back to uniform sampling.
    """
    parser = argparse.ArgumentParser(description="Keyframe Selection via Scene Segmentation")
    parser.add_argument('--features_dir', type=str, default='./score_and_features',
                        help="Directory holding <video_id>/dinov2_features.pkl and similarity_scores.json")
    # Replace with your local Video-MME data path: https://github.com/MME-Benchmarks/Video-MME
    parser.add_argument('--dataset_path', type=str, default='datasets/videomme/data',
                        help="Directory containing the Video-MME videos, named <url>.mp4")
    parser.add_argument('--questions_file', type=str, default='./videomme_json_file.json')
    parser.add_argument('--max_frames', type=int, default=16,
                        help="Number of keyframes to select per question")
    parser.add_argument('--window_size', type=int, default=3,
                        help="Triangular window width for frame-similarity smoothing")
    parser.add_argument('--scene_num', type=int, default=10,
                        help="Maximum number of scenes per video after merging")
    parser.add_argument('--min_scene_length', type=int, default=5,
                        help="Minimum spacing, in frames, between scene boundaries")
    parser.add_argument('--alpha', type=float, default=0.5,
                        help="Std-dev multiplier for the adaptive diversity thresholds")
    parser.add_argument('--output_dir', type=str, default='./')
    args = parser.parse_args()

    scene_analyzer = VideoSceneAnalyzer(args.window_size, args.min_scene_length)
    keyframe_selector = QueryKeyFrameSelector(args.alpha)
    with open(args.questions_file, 'r', encoding='utf-8') as f:
        questions_data = json.load(f)

    os.makedirs(args.output_dir, exist_ok=True)
    out_path = os.path.join(args.output_dir, f"videomme_{args.max_frames}frames_selected_by_efs.json")

    results = []
    for item in tqdm(questions_data, desc="Processing Videos"):
        # Copy the question dicts as well, so annotating them leaves questions_data untouched.
        item_copy = dict(item)
        item_copy['questions'] = [dict(q) for q in item['questions']]
        video_id = item['video_id']
        video_path = os.path.join(args.dataset_path, item['url'] + ".mp4")
        dino_feats, scores_data = load_features_and_scores(args.features_dir, video_id)

        # Per-video decision, hoisted out of the question loop.
        uniform = None
        if dino_feats is None:
            tqdm.write(f"[WARN] Missing features/scores for video {video_id}; using uniform sampling.")
            uniform = _uniform_frame_indices(video_path, args.max_frames)
        elif len(dino_feats) < args.max_frames:
            uniform = _uniform_frame_indices(video_path, args.max_frames)
        else:
            frame_map = scores_data['frame_indices']
            scenes = scene_analyzer.analyze(dino_feats, args.scene_num)

        for q_idx, question in enumerate(item_copy['questions']):
            if uniform is not None:
                final_indices = list(uniform)
            else:
                itm = np.array(scores_data['questions'][q_idx]["blip2_similarities"])
                sampled, _, _ = keyframe_selector.select(scenes, dino_feats, itm, args.max_frames)
                final_indices = map_and_ensure_keyframes(sampled, frame_map, args.max_frames, video_path, dino_feats)
            if len(final_indices) != args.max_frames:
                raise RuntimeError(f"Frame count mismatch: {len(final_indices)} != {args.max_frames}")
            question['keyframe_indices'] = final_indices
        results.append(item_copy)
        # Rewrite the whole file after every video so partial progress survives a crash.
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
    print(f"Results saved to: {out_path}")

if __name__ == "__main__":
    # Run the command-line pipeline only when executed as a script, not on import.
    main()